iT邦幫忙

2022 iThome 鐵人賽

0
AI & Data

JAX 好好玩系列 第 33

JAX 好好玩 (33) : 類別與 jit (1) : 重新定義 hash

  • 分享至 

  • xImage
  •  

(本貼文所列出的程式碼,皆以 colab 筆記本方式執行,可由此下載 )

如果一個函式,其「具有 class 類別型態」的輸入參數,我們可以用 jit 修飾它嗎?先看一個例子:

# 定義一個 user-defined class
# ====================================================================================
class MyClass01():
    def __init__(self, x=1.0, y=1.0):
        self.x = x
        self.y = y
 
# 定義 function
# ====================================================================================
@jax.jit
def my_func_01(cls: MyClass01, addition):
    return cls.x + cls.y + addition
 
my_class = MyClass01(1.0, 2.0)
 
try:
    my_func_01(my_class, 3.0)
except TypeError as e:
    print(f"There's a TypeError")
    print(e)

output :
https://ithelp.ithome.com.tw/upload/images/20221012/20129616iLSEPj8vav.png

從以上的程式片段可以知道,使用者自訂的 class 型態是不相容於 jit 的。當使用 jit 來編譯這一類的函式時,會產生 TypeError 例外。

第一個解法是使用「偏函式的靜態參數」宣告,老頭在先前的貼文中提到過這個方法:

# 將類別型態的輸入參數宣告為 static
# ====================================================================================
@partial(jax.jit, static_argnums=0)
def my_func_01(cls: MyClass01, addition):
    return cls.x + cls.y + addition
 
my_func_01(my_class, 3.0)

output :
DeviceArray(6., dtype=float32, weak_type=True)

這種解法有一個嚴重的限制!不能隨意修改這個 class 類別的參數值!以下面這段程式來說明:

my_class = MyClass01(1.0, 2.0)
print(f'Before Modification: {my_func_01(my_class, 3.0)}')
 
# 修改 my_class 的內容
my_class.x = 3.0
my_class.y = 4.0
print(f'After Modification: {my_func_01(my_class, 3.0)}')

output :
Before Modification: 6.0
After Modification: 6.0

class 類別值的修改,並不會反應在第二次呼叫上!為什麼呢?

在第一次呼叫時,JAX/JIT 將第一個參數 my_class 視為常數,將其編入可執行碼裏,執行後,把這個可執行碼暫存起來。

在第二次呼叫時,JAX/JIT 並沒有辦法發覺 my_class 內容已經被修改過了,而是認為它和第一次呼叫的 my_class 是一樣的,因此 JAX 就直接執行剛才暫存的執行碼。

好,大家一定接著想問,為什麼 JAX 會認為第一次呼叫時的 my_class 和第二次呼叫時的 my_class 是一樣的?

原因在於 Python 中 class 型別的 hash 機制 ! 請看下列程式段:

my_class = MyClass01(1.0, 2.0)
print(f'hash before modification: {hash(my_class)}')
 
my_class.x = 3.0
my_class.y = 4.0
print(f'hash after modification: {hash(my_class)}')

output :
hash before modification: 8728844717001
hash after modification: 8728844717001

在 my_class 修改前和修改後 hash(my_class) 的值都是一樣的,Python 既定的 class 型別的 hash() 算法,並不會把 class attribute 的值列入考量。而不巧的是,JAX 就是利用 hash() 來判斷 my_class 是不是相同。

因此,完整的解法是要重新定義 my_class 計算 hash 的方法。

# 定義一個 user-defined class, 並定義其 __hash__ 和 __eq__
# ====================================================================================
class MyClass02():
    def __init__(self, x=1.0, y=1.0):
        self.x = x
        self.y = y
 
    def __hash__(self):
        return hash((self.x, self.y))
 
    def __eq__(self, other):
        return (isinstance(other, MyClass02)) and\
               (self.x, self.y) == (other.x, other.y)
 
@partial(jax.jit, static_argnums=0)
def my_func_02(cls: MyClass02, addition):
    return cls.x + cls.y + addition
 
my_class02 = MyClass02(1.0, 2.0)
 
print(f'Before Modification: {my_func_02(my_class02, 3.0)}')
print(f'                     {hash(my_class02)}')
 
# 修改 my_class 的內容
my_class02.x = 3.0
my_class02.y = 4.0
print(f'After Modification: {my_func_02(my_class02, 3.0)}')
print(f'                    {hash(my_class02)}')

output :
https://ithelp.ithome.com.tw/upload/images/20221012/20129616SNXqhKw3Tv.png

要注意的是,當我們修改 class 中的 hash() 算法時,也要同時修正 eq() 的算法,以確保兩者的語義保持一致。這一部份老頭就不多著墨,有興趣的讀者可以去參考 Python 的 Data Model [33.1]。

參考:

[33.1] 可參考 Python Data Model: hash


上一篇
JAX 好好玩 (32) : 綜合演練 – 預測 MNIST
下一篇
JAX 好好玩 (34) : 類別與 jit (2) : 註冊類別為 pytree
系列文
JAX 好好玩40
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言